"""
Main program to train GCN based on SDMP
"""

# pylint: disable=anomalous-backslash-in-string
# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring
import argparse
import ast
import os
import pickle
import shutil
import sys
import warnings
from time import time

import numpy as np
import torch
# import torch.nn.functional as F

from train import distill_train as train
from model import MLP
from graph_dict import SDMP
from data_utils import load_data, SDMPDataPre
from data_utils import PlainLoader as MyLoader
from data_utils import PlainListLoader
from utils import Logger, export_train_conf, evaluate
from utils import load_mlp_conf_with_default, load_sdmp_conf_with_default

# Suppress every Python warning for the rest of the run.
warnings.filterwarnings(action='ignore')
# Set CURL_CA_BUNDLE to the empty string in this process's environment.
# NOTE(review): presumably a workaround for TLS certificate errors while
# downloading datasets -- confirm it is still required.
os.environ.update({"CURL_CA_BUNDLE": ""})

## Parse command line parameters
_parser = argparse.ArgumentParser(description="Argument of the main.")

# One row per option: (short flag, long flag, default, type, help text).
# The rows are registered in this order, which is also the order shown by -h.
for _short, _long, _default, _type, _help in (
    ("-c", "--config", "./config/cora/cora_test_SDMP0.yml", str,
     "Path to the configuration file. "),
    ("-n", "--dict", "", str,
     "Path to the output folder of a dictionary learning result. "),
    ("-d", "--data", "dataset", str,
     "Path the root datafolder. "),
    ("-r", "--result", "result/tmp", str,
     "Path to the result folder. "),
    ("-g", "--gnn", "result/cora/SAGE", str,
     "Path to the target GNN folder. "),
    ("-i", "--device", "cuda:6", str,
     "Device name to run pytorch code. "),
    ("-t", "--trials", 1, int,
     "Number of trials to run. "),
):
    _parser.add_argument(_short, _long, default=_default, type=_type, help=_help)

# Parsed once at import time from the process command line.
ARGS_GLOBAL = _parser.parse_args()


# Short aliases for the parsed command-line options.
DEVICE, DATA_ROOT_FOLDER = ARGS_GLOBAL.device, ARGS_GLOBAL.data
CONF_PATH, RES_FOLDER = ARGS_GLOBAL.config, ARGS_GLOBAL.result
TARGET_GNN_FOLDER = ARGS_GLOBAL.gnn

# Files looked up inside the dictionary learning output folder (--dict).
(
    SDMP_CONF_PATH,
    SDMP_THETA_PATH,
    SDMP_H_PATH,
    SDMP_LOG_PATH,
    SDMP_LOGTXT_PATH,
    SDMP_TARGET_SPLIT_PATH,
) = (
    os.path.join(ARGS_GLOBAL.dict, fname)
    for fname in (
        "conf.yml",
        "ThetaT.pkl",
        "h_model_stat.pkl",
        "log.pkl",
        "log.txt",
        "GNN_data_split_seed.txt",
    )
)

# Training configuration, read from the file given by --config.
train_conf = load_mlp_conf_with_default(CONF_PATH)
if "patience" not in train_conf:
    # if not specified, no early stopping
    train_conf["patience"] = 1000000

# SDMP configuration, read from conf.yml in the dictionary folder.
sdmp_conf = load_sdmp_conf_with_default(SDMP_CONF_PATH)

# Create the result folder together with any missing parents. exist_ok=True
# replaces the earlier "check os.path.exists, then create" sequence, which can
# fail if the folder appears between the check and the creation.
os.makedirs(RES_FOLDER, exist_ok=True)

# Dataset-specific folder under the data root, named after the dataset.
DATA_FOLDER = os.path.join(DATA_ROOT_FOLDER, train_conf["name"])

# Sub-folder of the result folder that receives the model states.
MODEL_FOLDER = os.path.join(RES_FOLDER, "models")
os.makedirs(MODEL_FOLDER, exist_ok=True)

# Replace stdout with the project Logger, constructed with the path of
# log.txt inside the result folder.
sys.stdout = Logger(os.path.join(RES_FOLDER, "log.txt"))

# Echo every command-line option followed by the training configuration.
for k, v in vars(ARGS_GLOBAL).items():
    print(k, v)

print(train_conf)

# Files belonging to the target GNN; every name is resolved relative to the
# folder given by --gnn. Most file names come from the two configurations.
(
    GNN_MODEL_PATH,
    GNN_TEACHER_PATH,
    GNN_CONF_PATH,
    GNN_ACC_PATH,
    GNN_DATA_SPLIT_PATH,
) = (
    os.path.join(TARGET_GNN_FOLDER, fname)
    for fname in (
        sdmp_conf['target_h_model_path'],
        train_conf['teacher_path'],
        sdmp_conf['target_h_model_conf_path'],
        sdmp_conf['target_h_model_metric_path'],
        "data_split_seed.txt",
    )
)

## Preprocessing
# Data loading
def _read_split_seed(path):
    """Return the integer data split seed stored in the text file at *path*.

    The file content is parsed with ast.literal_eval, so only Python literals
    are accepted; any other expression raises instead of being executed.
    """
    with open(path, "r") as fin:
        return int(ast.literal_eval(fin.read().strip()))


sdmp_target_split_seed = _read_split_seed(SDMP_TARGET_SPLIT_PATH)
gnn_data_split_seed = _read_split_seed(GNN_DATA_SPLIT_PATH)

# The seed recorded with the dictionary must equal the one recorded with the
# target GNN. An explicit raise keeps the check active under `python -O`.
if sdmp_target_split_seed != gnn_data_split_seed:
    raise ValueError(
        "Data split seed mismatch: "
        f"{SDMP_TARGET_SPLIT_PATH} has {sdmp_target_split_seed}, "
        f"{GNN_DATA_SPLIT_PATH} has {gnn_data_split_seed}"
    )
print(f"Experiment with data split seed {sdmp_target_split_seed}")
g = load_data(train_conf["name"], seed=sdmp_target_split_seed)

# The teacher embedding is only needed when the distillation term is active.
if train_conf['lam_distill'] > 0:
    print("Added distillation loss, loading teacher embedidng...")
    # Close the archive once the array has been read out of it.
    with np.load(GNN_TEACHER_PATH) as teacher_npz:
        teacher_emb = torch.from_numpy(teacher_npz['arr_0'])
else:
    teacher_emb = None
# Preprocessing the features
# The positional arguments mix GNN file paths with SDMP configuration values;
# the run of consecutive configuration values is unpacked in call order.
preprocesser = SDMPDataPre(
    train_conf["name"],
    sdmp_conf["feature_normalize"],
    sdmp_conf["target_h_mode"],
    GNN_CONF_PATH,
    GNN_MODEL_PATH,
    *(sdmp_conf[key] for key in (
        "target_h_model",
        "h_init_theta_mode",
        "h_init_theta_k",
        "h_init_theta_k_fanout",
        "theta_cand_mode",
        "theta_cand_k2",
        "theta_cand_k1",
        "theta_cand_fanout",
        "theta_cand_add_self",
    )),
    sdmp_conf,
    use_cache=True,
    cache_path=os.path.join(DATA_FOLDER, "SDMPPre"),
    device=DEVICE,
)
# NOTE(review): presumably prints a summary of the preprocessing state.
preprocesser.disp_states()

# Pull the preprocessed inputs out of the preprocessor.
theta_cand = preprocesser.theta_cand
h_init_theta = preprocesser.h_init_theta
X = preprocesser.X
target = preprocesser.target

print("Initializing the model...")
sdmp = SDMP(X, target, theta_cand, h_init_theta, sdmp_conf,
            device=DEVICE, verbose=True)

# Restore the dictionary learning result from the folder given by --dict.
sdmp.load(ARGS_GLOBAL.dict)

## deal with the partial test logic
# When enabled, only a leading fraction of the test and validation index sets
# is kept; the fractions come from the training configuration.
if train_conf["partial_test"]:
    print(
        "*** Partial test enabled, the testing results is only based "
        f"{train_conf['partial_test_ratio']} of the original test set and "
        f"{train_conf['partial_val_ratio']} of the original val set.***"
    )
    test_size = int(len(g.test_idx) * train_conf["partial_test_ratio"])
    val_size = int(len(g.val_idx) * train_conf["partial_val_ratio"])
    g.test_idx, g.val_idx = g.test_idx[:test_size], g.val_idx[:val_size]


"""
Deal with the inductive and partial test during the SDMP
If SDMP is trained with partial test and currently with partial test, 
only the corresponding positions and the training set and val set will be 
computed to save time. 
"""
tic = time()
if sdmp.ThetaT is None:
    print("SDMP from partial test, generating neccessary Theta...")
    if train_conf["partial_test"]:
        print("Detected partial test mode, only generating neccessary Theta...")
        eval_index = torch.cat((g.train_idx, g.val_idx, g.test_idx)).detach().cpu().numpy()
        print(f"Eval size {len(eval_index)}")
        sdmp.compute_ThetaT_from_h(target_nodes=eval_index)
    else:
        print("Full test on MLP. Generating all Theta and saving a cache...")
        sdmp.compute_ThetaT_from_h()
        sdmp.save_ThetaT(ARGS_GLOBAL.dict)

features = sdmp.infer_torch_node_approximal_features()
print(f"Preproccessing finished in {time()-tic:.4f} s.")

# main training
# One entry per hidden layer, all of the same configured width.
hidden_size = [train_conf['hidden_size'] for _ in range(train_conf['hidden_layer'])]
all_acc = []

for I in range(ARGS_GLOBAL.trials):
    # training loop: every trial builds and trains a fresh MLP
    in_size, out_size = features.shape[1], g.num_classes

    model = MLP(in_size, hidden_size, out_size, dropout=train_conf['dropout'])
    model = model.to(DEVICE)
    best_model_state = train(DEVICE, features, teacher_emb, g, model, train_conf)

    # The returned state is a pickled byte string: it is written to disk
    # unchanged and unpickled below to restore the model.
    state_path = os.path.join(MODEL_FOLDER, f"state_dict_{I}")
    with open(state_path, "wb") as state_file:
        state_file.write(best_model_state)

    model.load_state_dict(pickle.loads(best_model_state))

    # Evaluate the restored model on the test indices.
    inputs = features.to(DEVICE)
    labels = torch.tensor(g.ndata['label']).to(DEVICE)
    test_dataloader = MyLoader(inputs, labels,
                               train_conf["batch_size"], g.test_idx.to(DEVICE))
    acc = evaluate(model, test_dataloader)
    all_acc.append(acc.item())

# Mean and standard deviation over all trials.
mean_acc, std_acc = np.mean(all_acc), np.std(all_acc)
print(f"Overall f1_micro: {mean_acc:.4f}$\\pm${std_acc:.4f}")

# save results
# The per-trial scores are pickled as a plain list.
result_path = os.path.join(RES_FOLDER, "f1_micro.pkl")
with open(result_path, "wb") as fout:
    pickle.dump(all_acc, fout)

# back up important configurations
export_train_conf(os.path.join(RES_FOLDER, 'conf.yml'), train_conf)

# (source path, file name inside the result folder), copied in this order.
# ThetaT.pkl is only copied when the run was not a partial test.
backups = []
if not train_conf["partial_test"]:
    backups.append((SDMP_THETA_PATH, 'ThetaT.pkl'))
backups.extend([
    (SDMP_H_PATH, 'h_model_stat.pkl'),
    (SDMP_CONF_PATH, 'dict_conf.yml'),
    (SDMP_LOG_PATH, 'dict_log.pkl'),
    (SDMP_LOGTXT_PATH, 'dict_log.txt'),
])
for src, dst_name in backups:
    shutil.copyfile(src, os.path.join(RES_FOLDER, dst_name))

# Back up the target GNN files next to the results.
# The teacher embedding copy is best-effort: a file-system failure is reported
# and skipped. Only OSError is tolerated, so interrupts and programming errors
# are no longer swallowed.
try:
    shutil.copyfile(GNN_TEACHER_PATH, os.path.join(RES_FOLDER, 'GNN_teacher_emb.npz'))
except OSError as err:
    print(f"Teacher embedding not backed up: {err}")
shutil.copyfile(GNN_MODEL_PATH, os.path.join(RES_FOLDER, 'GNN_target_state'))
# The GNN configuration is only copied in the "internal" target mode.
if sdmp_conf["target_h_mode"] == "internal":
    shutil.copyfile(GNN_CONF_PATH, os.path.join(RES_FOLDER, 'GNN_conf.yml'))
shutil.copyfile(GNN_ACC_PATH, os.path.join(RES_FOLDER, 'GNN_f1.txt'))
shutil.copyfile(GNN_DATA_SPLIT_PATH, os.path.join(RES_FOLDER, 'GNN_data_split_seed.txt'))
